
CS224W - Bag of Tricks for Node Classification with GNN - GAT Normalization #9840

Open · wants to merge 16 commits into base: master
Conversation


@liuvince liuvince commented Dec 10, 2024

Add normalize parameter to GATConv and GATv2Conv.

Part of #9831, our final project for the Stanford CS224W course, this PR enables "GAT with Symmetric Normalized Adjacency Matrix" as described in "Bag of Tricks for Node Classification with Graph Neural Networks".

Details

  • Implemented gat_norm, inspired by gcn_norm, supporting edge_index given as a SparseTensor, a torch sparse tensor (is_torch_sparse_tensor), or a dense torch Tensor.
  • gat_norm is called after the alpha coefficients are computed and returns the updated edge_index and alpha, which are then passed as inputs to self.propagate.
  • Updated the docstrings of GATConv and GATv2Conv.
  • Added unit test cases.
  • Overrode the add_self_loops parameter: we remove self-loops from the initial graph before calling gat_norm, then add self-loops with normalization inside gat_norm, as described in the paper. We reused the tools already provided in the library, such as torch_sparse.fill_diag, to_edge_index, add_remaining_self_loops, add_self_loops, and to_torch_csr_tensor.
  • One concern: self-loops carry no learned attention weight regardless of add_self_loops, because we explicitly remove self-loops before the edge update. This is consistent with the paper's description and with gcn_norm, but differs from the paper's reference implementation, which also appears to use both out-degree and in-degree. We would appreciate your feedback on the preferred approach.
  • When is_torch_sparse_tensor(edge_index) == True, we had trouble converting edge_index and the corresponding values in att_mat back to the required format. Our workaround sorts the values of att_mat lexicographically so they match the index order of edge_index for the subsequent propagate and update steps.
  • When isinstance(edge_index, SparseTensor) and there are multiple heads (num_heads > 1), we need to perform the operation $D \alpha$, i.e. multiply a SparseTensor (whose values have more than one dimension) by the degree matrix. Our solution repeats the degree matrix num_heads times. We don't call repeat_interleave directly because it raises `"repeat_interleave_cpu" not implemented for 'Float'`; the current implementation follows the same behavior.
  • Only non-bipartite message passing is supported.
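To make the normalization step concrete, here is a minimal, self-contained sketch of what gat_norm might look like on dense COO inputs. It is an illustration of the idea, not the PR's actual code: the function name mirrors the PR, but the exact self-loop weight (1.0) and computing per-head degrees from the attention values are assumptions on our part.

```python
import torch

def gat_norm_sketch(edge_index, alpha, num_nodes):
    """Hypothetical sketch of symmetric attention normalization.

    edge_index: [2, E] COO indices (self-loops assumed already removed)
    alpha:      [E, H] attention coefficients, one column per head
    Returns edge_index with self-loops re-added and normalized alpha.
    """
    H = alpha.size(1)

    # Re-add self-loops; weight 1.0 per head is an assumption here.
    loop_index = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index, loop_index], dim=1)
    alpha = torch.cat([alpha, torch.ones(num_nodes, H)], dim=0)

    # Per-head degrees accumulated from the attention values (assumption),
    # then the symmetric rescaling alpha'_ij = d_i^{-1/2} * alpha_ij * d_j^{-1/2}.
    row, col = edge_index[0], edge_index[1]
    deg = torch.zeros(num_nodes, H).index_add_(0, row, alpha)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated nodes
    alpha = deg_inv_sqrt[row] * alpha * deg_inv_sqrt[col]
    return edge_index, alpha
```

Broadcasting the [N, H] degree tensor through deg_inv_sqrt[row] also shows why no explicit repeat_interleave is needed in the dense case: each head's column is rescaled independently.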
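The lexicographic-sort workaround for the torch sparse tensor case can be sketched as follows. This is a toy illustration under our own naming (lexsort_coo is hypothetical): after editing attention values stored alongside COO indices, sort the entries row-major by (row, col) so the values line up with the canonical index order of the sparse tensor.

```python
import torch

def lexsort_coo(edge_index, values, num_nodes):
    """Sort COO entries lexicographically by (row, col) and permute
    the per-edge values to match (hypothetical helper for illustration)."""
    row, col = edge_index
    # Encode (row, col) as a single row-major key, then argsort it.
    perm = (row * num_nodes + col).argsort()
    return edge_index[:, perm], values[perm]
```

After this permutation, the values tensor can be paired with the sparse tensor's indices for the subsequent propagate and update steps.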

Benchmarks

Measured on a single T4 GPU: normalization improves test accuracy on the CiteSeer and PubMed datasets at a modest computation-time cost.

| Dataset  | Test Accuracy | Test Accuracy (with GAT Norm) | Duration | Duration (with GAT Norm) |
|----------|---------------|-------------------------------|----------|--------------------------|
| Cora     | 0.831 ± 0.004 | 0.825 ± 0.005                 | 4.296s   | 5.172s                   |
| CiteSeer | 0.707 ± 0.005 | 0.715 ± 0.005                 | 4.767s   | 5.592s                   |
| PubMed   | 0.789 ± 0.003 | 0.796 ± 0.004                 | 6.603s   | 7.204s                   |

with the following run commands:

python gat.py --dataset=Cora
python gat.py --dataset=Cora --normalize

python gat.py --dataset=CiteSeer
python gat.py --dataset=CiteSeer --normalize  

python gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001
python gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001 --normalize 

@liuvince liuvince requested review from a team, wsad1, EdisonLeeeee and rusty1s as code owners December 10, 2024 20:54
@liuvince liuvince changed the title Gat normalization CS224W - Bag of Tricks for Node Classification with GNN - GAT Normalization Dec 10, 2024

codecov bot commented Dec 11, 2024

Codecov Report

Attention: Patch coverage is 93.06931% with 7 lines in your changes missing coverage. Please review.

Project coverage is 86.36%. Comparing base (1519e9f) to head (ce302c4).
Report is 1 commit behind head on master.

Files with missing lines Patch % Lines
torch_geometric/typing.py 0.00% 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #9840      +/-   ##
==========================================
+ Coverage   85.29%   86.36%   +1.06%     
==========================================
  Files         478      490      +12     
  Lines       31918    32386     +468     
==========================================
+ Hits        27225    27969     +744     
+ Misses       4693     4417     -276     

